from keras import regularizers
from keras import losses
from keras import backend as K
from keras.layers import Dropout, Reshape, Concatenate, Flatten, Bidirectional, Dense, Embedding, Input, Lambda, LSTM, \
    RepeatVector, TimeDistributed
from keras.models import Model
from keras.callbacks import ReduceLROnPlateau, LearningRateScheduler, ModelCheckpoint, TensorBoard
from tensorflow.keras.optimizers import Adam, RMSprop
import keras
import numpy as np
import os
from sklearn.metrics import precision_score, accuracy_score, precision_recall_fscore_support
import tensorflow as tf
from tensorflow.python.framework.ops import disable_eager_execution

#disable_eager_execution()


class MVAE(object):
    """Multimodal variational autoencoder with an auxiliary prediction head.

    Six inputs (two token sequences, a sequence of image embeddings, an image
    probability vector, an audio embedding and an edit embedding) are encoded
    into one latent vector.  Every input is reconstructed from that vector and
    a scalar outcome (``impred_output``) is predicted from it.

    After :meth:`create` the instance exposes ``encoder``, ``decoder``,
    ``impred``, ``autoencoder`` (the compiled end-to-end model) and
    ``get_features`` (a backend function returning the latent code).
    """

    # Weight applied to the KL term inside the reconstruction losses.
    KL_WEIGHT = 1 / 4
    # Number of time steps the image decoder emits.
    IMG_SEQ_LEN = 13

    def create(self, max_length, max_length_sticker, image_embed_size, image_embed_prob_size, audio_embed_size,
               edit_embed_size, latent_dim, reg_lambda, impred_lambda, embed_matrix):
        """Build the sub-models and compile the end-to-end autoencoder.

        Args:
            max_length: length of the main token sequence.
            max_length_sticker: length of the sticker token sequence.
            image_embed_size: width of each image embedding time step.
            image_embed_prob_size: width of the image probability vector.
            audio_embed_size: width of the audio embedding.
            edit_embed_size: width of the edit embedding.
            latent_dim: size of the latent code.
            reg_lambda: L2 factor for encoder/decoder kernels.
            impred_lambda: L2 factor for the prediction head kernels.
            embed_matrix: 2-D array of pretrained token embeddings; its first
                axis is the vocabulary size.
        """
        self.encoder = None
        self.decoder = None
        self.impred = None
        self.autoencoder = None
        self.embedding_matrix = embed_matrix
        self.vocab_size = self.embedding_matrix.shape[0]
        self.max_length = max_length
        self.max_length_sticker = max_length_sticker
        self.latent_dim = latent_dim
        self.reg_lambda = reg_lambda
        self.impred_lambda = impred_lambda
        self.image_embed_size = image_embed_size
        self.image_embed_size_prob = image_embed_prob_size
        self.audio_embed_size = audio_embed_size
        self.edit_embed_size = edit_embed_size

        input_txt = Input(shape=(self.max_length,), name='input_txt')
        input_txt_sticker = Input(shape=(self.max_length_sticker,), name='input_txt_sticker')
        input_img = Input(shape=(None, image_embed_size,), name='input_img')
        input_img_prob = Input((image_embed_prob_size,), name='input_img_prob')
        input_audio = Input((audio_embed_size,), name='input_audio')
        input_edit = Input((edit_embed_size,), name='input_edit')
        model_inputs = [input_txt, input_txt_sticker, input_img, input_img_prob, input_audio, input_edit]

        # Pass the requested latent size explicitly; otherwise the encoder
        # would silently fall back to its own default of 256.
        mse_img_loss, vae_mse_loss, encoded = self._build_encoder(
            input_txt, input_txt_sticker, input_img, input_img_prob, input_audio, input_edit,
            latent_dim=self.latent_dim)
        self.encoder = Model(inputs=model_inputs, outputs=encoded)

        # Stand-alone heads fed by a latent placeholder.  The builders create
        # new layers on every call, so these two models do NOT share weights
        # with ``self.autoencoder`` below.
        encoded_input = Input(shape=(self.latent_dim,))
        self.impred = Model(encoded_input, self._build_impred(encoded_input))
        self.decoder = Model(encoded_input, list(self._build_decoder(encoded_input)))

        decoder_output = self._build_decoder(encoded)
        impred_output = self._build_impred(encoded)

        self.autoencoder = Model(inputs=model_inputs, outputs=list(decoder_output) + [impred_output])
        # One loss per output, in output order: the two token reconstructions,
        # the image sequence, the three flat vectors, then the prediction head.
        self.autoencoder.compile(optimizer=Adam(1e-5),
                                 loss=['sparse_categorical_crossentropy', 'sparse_categorical_crossentropy',
                                       mse_img_loss, vae_mse_loss, vae_mse_loss, vae_mse_loss,
                                       'mean_squared_error'],
                                 metrics=['accuracy'])
        self.get_features = K.function(model_inputs, [encoded])
        # summary() prints by itself and returns None, so it is not wrapped in print().
        self.autoencoder.summary()

    def _l2(self):
        """Return a fresh L2 kernel regularizer weighted by ``reg_lambda``."""
        return regularizers.l2(self.reg_lambda)

    def _encode_text(self, tokens, seq_len, suffix):
        """Encode a token sequence: frozen embedding, two BiLSTMs, one dense layer.

        ``suffix`` is inserted into the layer names ('' for the main text,
        '_sticker' for the sticker text).
        """
        # The embedding width follows the pretrained matrix so the weights always fit.
        embed_dim = self.embedding_matrix.shape[1]
        embedded = Embedding(self.vocab_size, embed_dim, input_length=seq_len, name='txt_embed' + suffix,
                             trainable=False, weights=[self.embedding_matrix])(tokens)
        seq = Bidirectional(LSTM(32, return_sequences=True, name='lstm_txt' + suffix + '_1', activation='relu',
                                 kernel_regularizer=self._l2()), merge_mode='concat')(embedded)
        summary = Bidirectional(LSTM(32, return_sequences=False, name='lstm_txt' + suffix + '_2',
                                     activation='relu', kernel_regularizer=self._l2()),
                                merge_mode='concat')(seq)
        return Dense(32, activation='relu', name='dense_txt' + suffix, kernel_regularizer=self._l2())(summary)

    def _build_encoder(self, input_txt, input_txt_sticker, input_img, input_img_prob, input_audio, input_edit,
                       latent_dim=256):
        """Build the encoder graph and the KL-regularised reconstruction losses.

        Returns:
            A tuple ``(mse_img_loss, vae_mse_loss, encoded)``.  The first two
            are loss callables closing over this encoder's ``z_mean`` and
            ``z_log_var``; ``encoded`` is the sampled latent tensor.
        """
        # NOTE(review): the losses capture symbolic tensors from the graph; under
        # TF2 eager execution this pattern may require disable_eager_execution()
        # - confirm against the TensorFlow version in use.
        fc_txt = self._encode_text(input_txt, self.max_length, '')
        fc_txt_sticker = self._encode_text(input_txt_sticker, self.max_length_sticker, '_sticker')

        fc_img_prob_1 = Dense(1024, name='fc_img_prob_1', activation='relu',
                              kernel_regularizer=self._l2())(input_img_prob)
        fc_img_prob_2 = Dense(256, name='fc_img_prob_2', activation='relu',
                              kernel_regularizer=self._l2())(fc_img_prob_1)

        # The image sequence is collapsed to its last LSTM state.
        lstm_img_1 = LSTM(4096, return_sequences=False, name='lstm_img_1', activation='relu')(input_img)
        fc_img_1 = Dense(1024, name='fc_img_1', activation='relu', kernel_regularizer=self._l2())(lstm_img_1)
        fc_img_2 = Dense(256, name='fc_img_2', activation='relu', kernel_regularizer=self._l2())(fc_img_1)

        fc_audio_1 = Dense(64, name='fc_audio_1', activation='relu', kernel_regularizer=self._l2())(input_audio)
        fc_edit_1 = Dense(16, name='fc_edit_1', activation='relu', kernel_regularizer=self._l2())(input_edit)

        h = Concatenate(axis=-1, name='concat')(
            [fc_txt, fc_txt_sticker, fc_img_2, fc_img_prob_2, fc_audio_1, fc_edit_1])
        h = Dense(256, name='shared', activation='relu', kernel_regularizer=self._l2())(h)

        def sampling(args):
            # Reparameterisation: z = mean + exp(log_var / 2) * eps, with eps
            # drawn at stddev 0.01.
            z_mean_, z_log_var_ = args
            batch_size = K.shape(z_mean_)[0]
            epsilon = K.random_normal(shape=(batch_size, latent_dim), mean=0., stddev=0.01)
            return z_mean_ + K.exp(0.5 * z_log_var_) * epsilon

        z_mean = Dense(latent_dim, name='z_mean', activation='linear')(h)
        z_log_var = Dense(latent_dim, name='z_log_var', activation='linear')(h)

        def kl_term():
            # Per-sample KL term, averaged over the latent axis.
            return - 0.5 * K.mean(1 + z_log_var - K.square(z_mean) - K.exp(z_log_var), axis=-1)

        def vae_mse_loss(x, x_decoded_mean):
            return losses.mse(x, x_decoded_mean) + self.KL_WEIGHT * kl_term()

        def mse_img_loss(x, x_decoded_mean):
            # Both tensors are flattened first, so the MSE is taken over every
            # element of the batch at once.
            flat_mse = losses.mse(K.flatten(x), K.flatten(x_decoded_mean))
            return flat_mse + self.KL_WEIGHT * kl_term()

        encoded = Lambda(sampling, output_shape=(latent_dim,), name='lambda')([z_mean, z_log_var])
        return mse_img_loss, vae_mse_loss, encoded

    def _build_impred(self, encoded):
        """Build the prediction head that maps the latent code to ``impred_output``.

        NOTE(review): this method was called by ``create`` but missing from the
        class.  The output layer name is the one the training targets use; the
        hidden sizes, tanh/sigmoid activations and use of ``impred_lambda`` are
        an assumption - confirm against the intended architecture and any
        saved checkpoints.
        """
        h = Dense(64, activation='tanh', kernel_regularizer=regularizers.l2(self.impred_lambda))(encoded)
        h = Dense(32, activation='tanh', kernel_regularizer=regularizers.l2(self.impred_lambda))(h)
        return Dense(1, activation='sigmoid', name='impred_output')(h)

    def _decode_text(self, encoded, seq_len, suffix):
        """Decode the latent code into per-step vocabulary distributions.

        ``suffix`` is inserted into the layer names exactly as in ``_encode_text``.
        """
        context = Dense(32, name='dec_fc_txt' + suffix, activation='relu', kernel_regularizer=self._l2())(encoded)
        repeated = RepeatVector(seq_len)(context)
        seq = LSTM(32, return_sequences=True, activation='relu', name='dec_lstm_txt' + suffix + '_1',
                   kernel_regularizer=self._l2())(repeated)
        seq = LSTM(32, return_sequences=True, activation='relu', name='dec_lstm_txt' + suffix + '_2',
                   kernel_regularizer=self._l2())(seq)
        return TimeDistributed(Dense(self.vocab_size, activation='softmax'), name='decoded_txt' + suffix)(seq)

    def _build_decoder(self, encoded):
        """Build one reconstruction branch per input from the latent tensor.

        New layers are created on every call.

        Returns:
            ``(decoded_txt, decoded_txt_sticker, decoded_img, decoded_img_prob,
            decoded_audio, decoded_edit)``.
        """
        decoded_txt = self._decode_text(encoded, self.max_length, '')
        decoded_txt_sticker = self._decode_text(encoded, self.max_length_sticker, '_sticker')

        dec_fc_img_prob_1 = Dense(256, name='dec_fc_img_prob_1', activation='relu',
                                  kernel_regularizer=self._l2())(encoded)
        dec_fc_img_prob_2 = Dense(1024, name='dec_fc_img_prob_2', activation='relu',
                                  kernel_regularizer=self._l2())(dec_fc_img_prob_1)
        # Output widths follow the sizes given to create() so each
        # reconstruction matches its input.
        decoded_img_prob = Dense(self.image_embed_size_prob, name='decoded_img_prob', activation='relu',
                                 kernel_regularizer=self._l2())(dec_fc_img_prob_2)

        dec_fc_img_1 = Dense(256, name='dec_fc_img_1', activation='relu', kernel_regularizer=self._l2())(encoded)
        dec_fc_img_2 = Dense(1024, name='dec_fc_img_2', activation='relu',
                             kernel_regularizer=self._l2())(dec_fc_img_1)
        dec_fc_img_3 = Dense(4096, name='dec_fc_img_3', activation='relu',
                             kernel_regularizer=self._l2())(dec_fc_img_2)
        repeated_img = RepeatVector(self.IMG_SEQ_LEN)(dec_fc_img_3)
        decoded_lstm_img = LSTM(4096, return_sequences=True, activation='relu', name='dec_lstm_img_1',
                                kernel_regularizer=self._l2())(repeated_img)
        decoded_img = TimeDistributed(Dense(self.image_embed_size, activation='relu'),
                                      name='decoded_img')(decoded_lstm_img)

        dec_fc_audio_1 = Dense(64, name='dec_fc_audio_1', activation='relu',
                               kernel_regularizer=self._l2())(encoded)
        decoded_audio = Dense(self.audio_embed_size, name='decoded_audio', activation='relu')(dec_fc_audio_1)

        # The hidden layer needs its own name: Keras requires unique layer names
        # within a model, and 'decoded_edit' must stay on the output layer
        # because the training targets are keyed by it.
        dec_fc_edit_1 = Dense(16, name='dec_fc_edit_1', activation='relu', kernel_regularizer=self._l2())(encoded)
        decoded_edit = Dense(self.edit_embed_size, name='decoded_edit', activation='relu',
                             kernel_regularizer=self._l2())(dec_fc_edit_1)

        return decoded_txt, decoded_txt_sticker, decoded_img, decoded_img_prob, decoded_audio, decoded_edit


def train(sequence_length, sequence_length_sticker, image_embed_size, image_embed_prob_size, audio_embed_size,
          edit_embed_size, latent_dim, reg_lambda, impred_lambda, path):
    """Train the MVAE autoencoder on the ``*_d34`` arrays and checkpoint it under ``path``.

    Args:
        sequence_length: length of the main token sequence.
        sequence_length_sticker: length of the sticker token sequence.
        image_embed_size: width of each image embedding time step.
        image_embed_prob_size: width of the image probability vector.
        audio_embed_size: width of the audio embedding.
        edit_embed_size: width of the edit embedding.
        latent_dim: size of the latent code.
        reg_lambda: L2 factor forwarded to ``MVAE.create``.
        impred_lambda: prediction-head factor forwarded to ``MVAE.create``.
        path: output directory; TensorBoard logs go to ``path\\tb`` and
            checkpoints to ``path\\weights``.
    """
    feature_names = ('text', 'text_sticker', 'image_embed', 'image_embed_prob', 'yamnet_embed', 'edit_embed')

    def load_split(split):
        # Returns ([text, text_sticker, im, im_prob, aud, edit], labels) for
        # 'train' or 'test'; column 2 of the label array is the target.
        features = [np.load('E:\\data_pi\\%s_%s_d34.npy' % (split, name)) for name in feature_names]
        labels = np.load('E:\\data_pi\\%s_label_d34.npy' % split)[:, 2]
        return features, labels

    def targets(features, labels):
        # Every input is also its own reconstruction target; the token
        # sequences get a trailing axis for sparse categorical cross-entropy.
        text, text_sticker, im, im_prob, aud, edit = features
        return {'decoded_txt': np.expand_dims(text, -1),
                'decoded_txt_sticker': np.expand_dims(text_sticker, -1),
                'decoded_img': im,
                'decoded_img_prob': im_prob,
                'decoded_audio': aud,
                'decoded_edit': edit,
                'impred_output': labels}

    train_x, res = load_split('train')
    test_x, test_res = load_split('test')
    embed_matrix = np.load('E:\\data_pi\\embedding_matrix.npy')

    # exist_ok avoids the check-then-create race of os.path.exists + makedirs.
    for directory in (path, path + '\\tb', path + '\\weights'):
        os.makedirs(directory, exist_ok=True)

    tensorboard = TensorBoard(log_dir=path + '\\tb', write_graph=True, write_images=True)
    checkpoint = ModelCheckpoint(path + '\\weights\\{epoch:02d}.hdf5', monitor='loss', verbose=1, save_best_only=True,
                                 mode='auto')
    # NOTE(review): min_lr equals the 1e-5 learning rate the model is compiled
    # with, so this callback cannot actually lower the rate - confirm intent.
    reduce_lr = ReduceLROnPlateau(monitor='impred_output_loss', factor=0.2, patience=6, min_lr=1e-5)

    model = MVAE()
    model.create(sequence_length, sequence_length_sticker, image_embed_size, image_embed_prob_size,
                 audio_embed_size, edit_embed_size, latent_dim, reg_lambda, impred_lambda, embed_matrix)
    model.autoencoder.fit(x=train_x,
                          y=targets(train_x, res),
                          batch_size=32, epochs=300, callbacks=[checkpoint, tensorboard, reduce_lr], shuffle=True,
                          validation_data=(test_x, targets(test_x, test_res)))

def save_features(sequence_length, image_embed_size, image_embed_prob_size, audio_embed_size, edit_embed_size,
                  latent_dim, reg_lambda, impred_lambda, path, sequence_length_sticker=None):
    """Encode the ``*_s4`` test arrays and save the latent features under ``path\\features``.

    Args:
        sequence_length: length of the main token sequence.
        image_embed_size: width of each image embedding time step.
        image_embed_prob_size: width of the image probability vector.
        audio_embed_size: width of the audio embedding.
        edit_embed_size: width of the edit embedding.
        latent_dim: size of the latent code (and of each saved feature row).
        reg_lambda: L2 factor forwarded to ``MVAE.create``.
        impred_lambda: prediction-head factor forwarded to ``MVAE.create``.
        path: model directory; weights are read from ``path\\3\\90.hdf5``.
        sequence_length_sticker: length of the sticker token sequence;
            defaults to ``sequence_length`` when omitted.
    """
    if sequence_length_sticker is None:
        sequence_length_sticker = sequence_length

    test_text = np.load('E:\\data_pi\\test_text_s4.npy')
    # NOTE(review): the model takes six inputs, so the sticker text and image
    # probabilities are needed too.  These two file names follow the naming of
    # the *_d34 files used for training - confirm that they exist.
    test_text_sticker = np.load('E:\\data_pi\\test_text_sticker_s4.npy')
    test_im = np.load('E:\\data_pi\\test_image_embed_s4.npy')
    test_im_prob = np.load('E:\\data_pi\\test_image_embed_prob_s4.npy')
    test_aud = np.load('E:\\data_pi\\test_yamnet_embed_s4.npy')
    test_edit = np.load('E:\\data_pi\\test_edit_embed_s4.npy')

    embed_matrix = np.load('E:\\data_pi\\embedding_matrix.npy')

    model = MVAE()
    model.create(sequence_length, sequence_length_sticker, image_embed_size, image_embed_prob_size,
                 audio_embed_size, edit_embed_size, latent_dim, reg_lambda, impred_lambda, embed_matrix)
    model.autoencoder.load_weights(path + '\\3\\90.hdf5')

    os.makedirs(path + '\\features', exist_ok=True)

    # Same input order as the model: text, sticker text, image, image prob, audio, edit.
    inputs = (test_text, test_text_sticker, test_im, test_im_prob, test_aud, test_edit)
    # Seed with an empty (0, latent_dim) block so an empty test set still
    # yields a well-shaped array, then concatenate once instead of per sample.
    chunks = [np.empty((0, latent_dim))]
    for i in range(test_text.shape[0]):
        sample = [arr[i:i + 1] for arr in inputs]
        chunks.append(model.get_features(sample)[0])
    learnt_features = np.concatenate(chunks)
    np.save(path + '\\features\\vae_impred_s4', learnt_features)


def test(sequence_length, sequence_length_sticker, image_embed_size, image_embed_prob_size, audio_embed_size,
         edit_embed_size, latent_dim, reg_lambda, impred_lambda, path):
    """Evaluate the checkpoint ``path\\weights\\90.hdf5`` on the ``*_d34`` test arrays.

    The model is rebuilt with the given hyperparameters, the weights are
    loaded, and prediction is repeated ten times.  Each round thresholds the
    last model output at 0.5 and prints accuracy followed by
    precision/recall/F-score/support against column 2 of the label array.
    """
    data_dir = 'E:\\data_pi\\'
    test_text = np.load(data_dir + 'test_text_d34.npy')
    test_text_sticker = np.load(data_dir + 'test_text_sticker_d34.npy')
    test_im = np.load(data_dir + 'test_image_embed_d34.npy')
    test_im_prob = np.load(data_dir + 'test_image_embed_prob_d34.npy')
    test_label = np.load(data_dir + 'test_label_d34.npy')[:, 2]
    test_aud = np.load(data_dir + 'test_yamnet_embed_d34.npy')
    test_edit = np.load(data_dir + 'test_edit_embed_d34.npy')

    embed_matrix = np.load(data_dir + 'embedding_matrix.npy')

    model = MVAE()
    model.create(sequence_length, sequence_length_sticker, image_embed_size, image_embed_prob_size,
                 audio_embed_size, edit_embed_size, latent_dim, reg_lambda, impred_lambda, embed_matrix)
    model.autoencoder.load_weights(path + '\\weights\\90.hdf5')

    model_inputs = [test_text, test_text_sticker, test_im, test_im_prob, test_aud, test_edit]
    # presumably repeated to show run-to-run variation from the encoder's
    # random sampling - confirm intent
    for _ in range(10):
        outputs = model.autoencoder.predict(model_inputs)
        pred = outputs[-1]
        # Binarise in place; both masks are taken before any value is overwritten.
        positive = pred >= 0.5
        negative = pred < 0.5
        pred[positive] = 1
        pred[negative] = 0
        print(accuracy_score(test_label, pred))
        print(precision_recall_fscore_support(test_label, pred))


if __name__ == '__main__':
    # Training and evaluation share the same hyperparameters and model directory:
    # (sequence_length, sequence_length_sticker, image_embed_size, image_embed_prob_size,
    #  audio_embed_size, edit_embed_size, latent_dim, reg_lambda, impred_lambda, path)
    run_args = (20, 20, 4096, 3089, 543, 157, 256, 0.05, 0.3, 'E:\\models\\vae_impred_0.05_0.3')
    train(*run_args)
    test(*run_args)
